{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "inference_from_saved_model_tf2_colab.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cT5cdSLPX0ui"
      },
      "source": [
        "# Intro to Object Detection Colab\n",
        "\n",
        "Welcome to the object detection colab! This demo will take you through the steps of running an \"out-of-the-box\" detection model in SavedModel format on a collection of images.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vPs64QA1Zdov"
      },
      "source": [
        "## Setup and imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OBzb04bdNGM8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install --upgrade --pre \"tensorflow==2.2.0\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NgSXyvKSNHIl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "# Step out of any enclosing checkout of the models repository; clone it only\n",
        "# when we did not start inside one and no local copy is present.\n",
        "started_inside_models = \"models\" in pathlib.Path.cwd().parts\n",
        "while \"models\" in pathlib.Path.cwd().parts:\n",
        "  os.chdir('..')\n",
        "if not started_inside_models and not pathlib.Path('models').exists():\n",
        "  !git clone --depth 1 https://github.com/tensorflow/models"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rhpPgW7TNLs6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%%bash\n",
        "# Install the Object Detection API.\n",
        "# (%%bash has to be the first line of the cell to be recognized as a cell magic.)\n",
        "cd models/research/\n",
        "protoc object_detection/protos/*.proto --python_out=.\n",
        "cp object_detection/packages/tf2/setup.py .\n",
        "python -m pip install ."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "yn5_uV1HLvaz",
        "colab": {}
      },
      "source": [
        "import io\n",
        "import os\n",
        "import scipy.misc\n",
        "import numpy as np\n",
        "import six\n",
        "import time\n",
        "\n",
        "from six import BytesIO\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "from PIL import Image, ImageDraw, ImageFont\n",
        "\n",
        "import tensorflow as tf\n",
        "from object_detection.utils import visualization_utils as viz_utils\n",
        "\n",
        "%matplotlib inline"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "-y9R0Xllefec",
        "colab": {}
      },
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image file into a uint8 RGB numpy array.\n",
        "\n",
        "  The file is read through tf.io.gfile, decoded with PIL and converted to\n",
        "  RGB, so grayscale, palette and RGBA inputs also produce three channels.\n",
        "  By convention the array has shape (height, width, channels).\n",
        "\n",
        "  Args:\n",
        "    path: path of the image file; it is opened with tf.io.gfile.GFile.\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (img_height, img_width, 3)\n",
        "  \"\"\"\n",
        "  # The context manager closes the file handle once the bytes are read.\n",
        "  with tf.io.gfile.GFile(path, 'rb') as image_file:\n",
        "    img_data = image_file.read()\n",
        "  # Force three channels; otherwise non-RGB images break the (h, w, 3) layout.\n",
        "  image = Image.open(BytesIO(img_data)).convert('RGB')\n",
        "  # Converting the PIL image directly yields (height, width, 3) and avoids\n",
        "  # materializing every pixel as a Python tuple via getdata().\n",
        "  return np.array(image, dtype=np.uint8)\n",
        "\n",
        "# COCO label map, written out inline: class id -> {'id': ..., 'name': ...}.\n",
        "# The ids are not contiguous (for example 12 and 26 are absent).\n",
        "category_index = {\n",
        "    class_id: {'id': class_id, 'name': class_name}\n",
        "    for class_id, class_name in (\n",
        "        (1, 'person'), (2, 'bicycle'), (3, 'car'), (4, 'motorcycle'),\n",
        "        (5, 'airplane'), (6, 'bus'), (7, 'train'), (8, 'truck'),\n",
        "        (9, 'boat'), (10, 'traffic light'), (11, 'fire hydrant'),\n",
        "        (13, 'stop sign'), (14, 'parking meter'), (15, 'bench'),\n",
        "        (16, 'bird'), (17, 'cat'), (18, 'dog'), (19, 'horse'),\n",
        "        (20, 'sheep'), (21, 'cow'), (22, 'elephant'), (23, 'bear'),\n",
        "        (24, 'zebra'), (25, 'giraffe'), (27, 'backpack'), (28, 'umbrella'),\n",
        "        (31, 'handbag'), (32, 'tie'), (33, 'suitcase'), (34, 'frisbee'),\n",
        "        (35, 'skis'), (36, 'snowboard'), (37, 'sports ball'), (38, 'kite'),\n",
        "        (39, 'baseball bat'), (40, 'baseball glove'), (41, 'skateboard'),\n",
        "        (42, 'surfboard'), (43, 'tennis racket'), (44, 'bottle'),\n",
        "        (46, 'wine glass'), (47, 'cup'), (48, 'fork'), (49, 'knife'),\n",
        "        (50, 'spoon'), (51, 'bowl'), (52, 'banana'), (53, 'apple'),\n",
        "        (54, 'sandwich'), (55, 'orange'), (56, 'broccoli'), (57, 'carrot'),\n",
        "        (58, 'hot dog'), (59, 'pizza'), (60, 'donut'), (61, 'cake'),\n",
        "        (62, 'chair'), (63, 'couch'), (64, 'potted plant'), (65, 'bed'),\n",
        "        (67, 'dining table'), (70, 'toilet'), (72, 'tv'), (73, 'laptop'),\n",
        "        (74, 'mouse'), (75, 'remote'), (76, 'keyboard'), (77, 'cell phone'),\n",
        "        (78, 'microwave'), (79, 'oven'), (80, 'toaster'), (81, 'sink'),\n",
        "        (82, 'refrigerator'), (84, 'book'), (85, 'clock'), (86, 'vase'),\n",
        "        (87, 'scissors'), (88, 'teddy bear'), (89, 'hair drier'),\n",
        "        (90, 'toothbrush'),\n",
        "    )\n",
        "}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QwcBC2TlPSwg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Download the saved model and unpack it into models/research/object_detection/test_data/.\n",
        "# `wget -c` resumes or skips an archive that is already on disk, and extracting\n",
        "# with `tar -C` straight into the target directory keeps this cell safe to re-run\n",
        "# (a second `mv` onto the already-populated directory would fail).\n",
        "!wget -c http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d5_coco17_tpu-32.tar.gz\n",
        "!tar -xf efficientdet_d5_coco17_tpu-32.tar.gz -C models/research/object_detection/test_data/"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Z2p-PmKLYCVU",
        "colab": {}
      },
      "source": [
        "# Restore the detection SavedModel and report how long the restore took.\n",
        "saved_model_dir = 'models/research/object_detection/test_data/efficientdet_d5_coco17_tpu-32/saved_model/'\n",
        "\n",
        "start_time = time.time()\n",
        "tf.keras.backend.clear_session()\n",
        "detect_fn = tf.saved_model.load(saved_model_dir)\n",
        "end_time = time.time()\n",
        "\n",
        "elapsed_time = end_time - start_time\n",
        "print('Elapsed time: {}s'.format(str(elapsed_time)))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vukkhd5-9NSL",
        "colab": {}
      },
      "source": [
        "import time\n",
        "\n",
        "image_dir = 'models/research/object_detection/test_images'\n",
        "# Images are looked up as image1.jpg ... image<num_images>.jpg inside image_dir.\n",
        "num_images = 2\n",
        "\n",
        "# One subplot row per image. Sizing the figure here replaces the per-iteration\n",
        "# write to the global plt.rcParams, so later cells keep their default size.\n",
        "fig, axes = plt.subplots(num_images, 1, figsize=(42, 21), squeeze=False)\n",
        "\n",
        "elapsed = []\n",
        "for i in range(num_images):\n",
        "  image_path = os.path.join(image_dir, 'image' + str(i + 1) + '.jpg')\n",
        "  image_np = load_image_into_numpy_array(image_path)\n",
        "  # Add a leading axis; element [0] of each output is read back below.\n",
        "  input_tensor = np.expand_dims(image_np, 0)\n",
        "  # Time only the detect_fn call, not image loading or drawing.\n",
        "  start_time = time.time()\n",
        "  detections = detect_fn(input_tensor)\n",
        "  end_time = time.time()\n",
        "  elapsed.append(end_time - start_time)\n",
        "\n",
        "  # Draw on a copy so image_np keeps the unannotated pixels.\n",
        "  image_np_with_detections = image_np.copy()\n",
        "  viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_np_with_detections,\n",
        "      detections['detection_boxes'][0].numpy(),\n",
        "      detections['detection_classes'][0].numpy().astype(np.int32),\n",
        "      detections['detection_scores'][0].numpy(),\n",
        "      category_index,\n",
        "      use_normalized_coordinates=True,\n",
        "      max_boxes_to_draw=200,\n",
        "      min_score_thresh=.40,\n",
        "      agnostic_mode=False)\n",
        "  axes[i, 0].imshow(image_np_with_detections)\n",
        "\n",
        "mean_elapsed = sum(elapsed) / float(len(elapsed))\n",
        "print('Elapsed time: ' + str(mean_elapsed) + ' second per image')"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}